package org.funny.nn.som.algorithm;

import java.util.Arrays;
import java.util.Comparator;
import java.util.Random;

/**
 * 自组织映射神经网络(Self-Organizing Map,SOM)是无监督学习方法中一类重要方法,
 * 因为BPNN中发现在JAVA中用 double[][] 矩阵运算，速度不及 float[][], 然而准确性的差距却难以感知。所以主要运算为float
 * 这里输入为 float[][] 表示N个M维度的向量。 而神经层实现了 线性的float[][] - 线性的神经层，以及float[][][]二维的神经层，并没有实现六边形网格。
 * 原理都是通过 竞争学习，使得神经层能够趋近输入层(自适应)。并且由于和边上节点的关系，能够发现输入层的特征。
 * 两个例子，一个解决旅行商问题，一个为自动分类。
 * @author LinLW
 */
public class Som {

    /**
    * 用于观察网络层的钩子。
    *
    */
    public interface DataWatcher {
        /**
        * 展示网络层在某时刻的状态
        *
        * @param iterations 当前循环第几轮
        * @param finishInputCount 当前完成了第几个input的处理
        * @param network 神经网络数据情况
        */
        void watch(int iterations,int finishInputCount, Object network);
    }
    /**
    * 用于计算距离的钩子
    */
    public interface DistanceCalculator {
        /**
         * 计算输入的某个向量，和神经网络中的某个神经向量的距离
         * @param input 输入向量
         * @param node 神经向量
         * @return float 距离
         */
        float distance(float[] input, float[] node);
    }
    public float getLearningRate() {
        return learningRate;
    }

    public void setLearningRate(float learningRate) {
        this.learningRate = learningRate;
    }

    public float getDecayedRate() {
        return decayedRate;
    }

    public void setDecayedRate(float decayedRate) {
        this.decayedRate = decayedRate;
    }

    public int getIterations() {
        return iterations;
    }

    public void setIterations(int iterations) {
        this.iterations = iterations;
    }

    public int getNetworkSize() {
        return networkSize;
    }

    public void setNetworkSize(int networkSize) {
        this.networkSize = networkSize;
    }

    public float getNeighborRadix() {
        return neighborRadix;
    }

    public void setNeighborRadix(float neighborRadix) {
        this.neighborRadix = neighborRadix;
    }

    public float[][] getInput() {
        return input;
    }

    public void setInput(float[][] input) {
        this.input = input;
    }

    public DistanceCalculator getDistanceCalculator() {
        return distanceCalculator;
    }

    public void setDistanceCalculator(DistanceCalculator distanceCalculator) {
        this.distanceCalculator = distanceCalculator;
    }

    public DataWatcher getDataWatcher() {
        return dataWatcher;
    }

    public void setDataWatcher(DataWatcher dataWatcher) {
        this.dataWatcher = dataWatcher;
    }


    private float learningRate;
    private float decayedRate;
    private int iterations;
    private int networkSize;
    private float neighborRadix;
    private float[][] input;
    private float[][] networkD1;//一维神经网络
    private float[][][] networkD2;//二维神经网络
    private DistanceCalculator distanceCalculator;
    private DataWatcher dataWatcher;

    /**
     * 开始训练，该训练使用一维神经网络
     *
     */
    public void train4tsp() {

        // 产生一个神经网络
        networkD1 = generateNetworkD1(networkSize);

        for(int i=0; i<iterations; i++){
            if(dataWatcher!=null){
                dataWatcher.watch(i,0,networkD1);
            }
            int finish=0;
            for(float[] inputRow:input) {
                int winner_idx = selectClosest(networkD1, inputRow);

                float[] gaussian = getNeighborhood(winner_idx, neighborRadix, networkD1.length);

                for (int n=0; n<networkD1.length; n++) {
                    float[] v = networkD1[n];
                    float ratio = gaussian[n] * learningRate;
                    for(int p=0; p<v.length; p++){
                        v[p]+= ratio * (inputRow[p]-v[p]);
                    }
                }
                if(dataWatcher!=null) {
                    dataWatcher.watch(i, ++finish, networkD1);
                }
            }
            learningRate *= decayedRate;
            neighborRadix *= decayedRate;

        }

    }

    /**
     * 开始训练，该训练使用二维神经网络
     *
     */
    public void train4classifiy() {

        // 产生一个神经网络
        networkD2 = generateNetworkD2(networkSize);

        for(int i=0; i<iterations; i++){
            if(dataWatcher!=null){
                dataWatcher.watch(i,0,networkD2);
            }
            int finish=0;
            for(float[] inputRow:input) {
                int[] winner_idx = selectClosest(networkD2, inputRow);

                float[][] gaussian = getNeighborhood(winner_idx, neighborRadix, networkD2.length);

                for (int n=0; n<networkD2.length; n++) {
                    for (int m=0; m<networkD2.length; m++) {
                        float[] v = networkD2[n][m];
                        float ratio = gaussian[n][m] * learningRate;
                        for (int p = 0; p < v.length; p++) {
                            v[p] += ratio * (inputRow[p] - v[p]);
                        }
                    }
                }
                if(dataWatcher!=null) {
                    dataWatcher.watch(i, ++finish, networkD2);
                }
            }
            learningRate *= decayedRate;
            neighborRadix *= decayedRate;
        }
    }

    /**
     * 根据每个input点，最近的神经网络节点顺序，对input节点重新排序。
     * 这个排序就就是整理出来 旅行商，经过城市的顺序。
     * @return int[]
     */
    public int[] getRoute() {
        int[][] temp=new int[input.length][2];
        for( int i=0;i<input.length;i++){
            float[] city=input[i];
            temp[i]=new int[2];
            temp[i][0]=i;
            temp[i][1]=selectClosest(networkD1,city);
        }

        Arrays.sort(temp, Comparator.comparingInt(row -> row[1]));
        int[] ret=new int[input.length];
        for(int i=0;i<temp.length;i++){
            ret[i]=temp[i][0];
        }
        return ret;

    }

    /**
     * 寻找最近的节点编号
     *
     *
     * @param network 所有的神经网络向量(一维)
     * @param target 目标向量
     * @return: int 位置
     */
    private int selectClosest(float[][] network, float[] target){
        ensureDistanceCalculator();
        double dis=Double.MAX_VALUE;
        int find =-1;
        for(int i=0;i<network.length;i++ ){
            float[] p=network[i];
            double disCur= distanceCalculator.distance(p,target);
            if(disCur<dis){
                dis=disCur;
                find=i;
            }
        }
        return find;
    }
    /**
     * 寻找最近的节点位置
     *
     *
     * @param network 所有的神经网络向量(二维)
     * @param target 目标向量
     * @return: int 位置
     */
    public int[] selectClosest(float[][][] network, float[] target){
        ensureDistanceCalculator();
        double dis=Double.MAX_VALUE;
        int findX =-1;
        int findY =-1;
        for(int i=0;i<network.length;i++ ){
            float[][] q=network[i];
            for(int j=0;j<q.length;j++ ) {
                float[] p = q[j];
                double disCur = distanceCalculator.distance(p, target);
                if (disCur < dis) {
                    dis = disCur;
                    findX = i;
                    findY = j;
                }
            }
        }
        return new int[]{findX,findY};
    }
    /**
     * 如果没有距离计算公式。默认用空间位置距离计算公司代替
     */
    private void ensureDistanceCalculator(){
        if(distanceCalculator==null){
            distanceCalculator=(a,b)->{
                float sum=0.0F;
                for(int i=0;i<a.length;i++){
                    sum+=(a[i]-b[i])*(a[i]-b[i]);
                }
                return (float)Math.sqrt(sum);
            };
        }
    }

    /**
     * 根据正态分布取得各个节点学习率修正(一维神经网络的情况)
     * @param center 中心点位置
     * @param radix 半径
     * @param domain 数据总长度
     * @return float[]
     */
    private float[] getNeighborhood(int center,float radix,int domain){

        float[] ret=new float[domain];
        for(int i=0;i<domain;i++){
            int deltas=Math.abs(center-i);
            int distances=Math.min(deltas,domain-deltas);
            float gaussian =  (float) Math.exp(-(distances*distances)* 1.0 / (2*radix*radix));
            ret[i]=gaussian;
        }
        return ret;
    }
    /**
     * 根据正态分布取得各个节点学习率修正(二维神经网络的情况)
     * @param center 中心点位置
     * @param radix 半径
     * @param domain 数据总长度(边长)
     * @return float[][]
     */
    private float[][] getNeighborhood(int[] center,float radix,int domain){

        float[][] ret=new float[domain][];
        for(int i=0;i<domain;i++){
            ret[i]=new float[domain];
            for(int j=0;j<domain;j++) {
                int deltasI = Math.abs(center[0] - i);
                int distancesI = Math.min(deltasI, domain - deltasI);
                int deltasJ = Math.abs(center[1] - j);
                int distancesJ = Math.min(deltasJ, domain - deltasJ);
//                int distances=Math.min(distancesI,distancesJ);
                float gaussian = (float) Math.exp(-(distancesI * distancesI+ distancesJ*distancesJ) * 1.0 / (2 * radix * radix));
                ret[i][j] = gaussian;
            }
        }
        return ret;
    }

    /**
     * 生成初始化的，一维的神经网络
     *
     * @param size 总长度
     * @return: float[][]
     */
    private float[][] generateNetworkD1(int size){
        int d=input[0].length;
        float[][] ret=new float[size][];
        for(int i=0;i<size;i++){
            float[]row= ret[i]=new float[d];
            for(int k=0;k<d;k++){
                row[k]=RANDOM.nextFloat();
            }
        }
        return ret;
    }

    /**
     * 生成初始化的，二维的神经网络
     *
     * @param size 总长度(边长)
     * @return: float[][][]
     */
    private float[][][] generateNetworkD2(int size){
        int d=input[0].length;
        float[][][] ret=new float[size][][];
        for(int i=0;i<size;i++){
            ret[i]=new float[size][];
            for(int j=0;j<size;j++) {
                float[] row = ret[i][j] = new float[d];
                for (int k = 0; k < d; k++) {
                    row[k] = RANDOM.nextFloat();
                }
            }
        }
        return ret;
    }
    private static Random RANDOM=new Random();



}